import taichi as ti, numpy as np

from ti_rt.color.space import XYZ2lRGB

from ..utils.common import *
from ..config import HERO_SAMPLE_N

# Public API of this module.
__all__ = [
    "Hero",
    "Spectrum",
    "RGB2Spectrum",
    "Spectrum2RGB",
]

# Spectral payload carried per sample: one float per hero wavelength
# (HERO_SAMPLE_N lanes). Lane i corresponds to wavelength
# ``lamda + i * sampleStep`` in Spectrum2RGB.heroSample / heroSampleInv.
Hero = ti.types.vector(HERO_SAMPLE_N, float)


@ti.data_oriented
class Spectrum:
    """A tabulated 1-D spectral curve loaded from a CSV file.

    Each line of the table is ``wavelength,value[,...]``; only the first two
    columns are read. Samples are linearly interpolated by :meth:`sample`
    using a constant spacing derived from the min/max wavelength, so the
    table is assumed to be sorted and uniformly spaced
    (NOTE(review): the loader does not verify this — confirm table format).
    """

    def __init__(self, table_path):
        """Load the spectral table stored at ``table_path``.

        Raises:
            ValueError: if the table holds fewer than two samples, because the
                sample spacing (``lambdaRange``) would be undefined.
        """
        # Allocated here but not read by any method of this class;
        # presumably written/read by callers — TODO confirm.
        self.white_point = ti.Vector.field(3, dtype=float, shape=1)

        wavelengths = []
        data = []
        with open(table_path) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. at end of file)
                values = line.split(',', 2)
                wavelengths.append(float(values[0]))
                data.append(float(values[1]))

        if len(data) < 2:
            raise ValueError(
                f"spectrum table {table_path!r} needs at least 2 samples, "
                f"got {len(data)}"
            )

        self.lambdaMin = min(wavelengths)
        self.lambdaMax = max(wavelengths)
        self.size = len(data)
        # Spacing between consecutive samples (despite the name).
        self.lambdaRange = (self.lambdaMax - self.lambdaMin) / (self.size - 1)
        self.power = ti.field(shape=self.size, dtype=float)
        self.power.from_numpy(np.array(data, dtype=np.float32))

    @ti.func
    def sample(self, Lambda: float) -> float:
        # Linearly interpolate the tabulated value at wavelength ``Lambda``;
        # returns 0 outside [lambdaMin, lambdaMax].
        ret = 0.0
        if self.lambdaMin <= Lambda <= self.lambdaMax:
            offset = (Lambda - self.lambdaMin) / self.lambdaRange
            # At Lambda == lambdaMax, offset == size - 1, so an unclamped idx
            # would read power[size], which is out of bounds. With the clamp,
            # the weight ``offset - idx`` becomes 1 and the last sample is
            # returned. This assumes ``mix`` (from utils.common) is a GLSL-style
            # lerp.
            idx = min(int(offset), self.size - 2)
            ret = mix(self.power[idx], self.power[idx + 1], offset - idx)
        return ret

    @ti.kernel
    def scale(self, coff: float):
        # Multiply every tabulated value in place by ``coff``.
        for i in self.power:
            self.power[i] *= coff


class RGB2Spectrum:
    """RGB -> spectrum upsampling via a precomputed coefficient table.

    For each (largest-channel index, z, y, x) grid cell, the table stores a
    ``Vec3`` of polynomial coefficients. :meth:`fetch` trilinearly
    interpolates them for an RGB value, and :meth:`eval` turns the
    coefficients into a value in [0, 1] at a given wavelength.
    NOTE(review): the scheme looks like pbrt-v4's RGBToSpectrumTable —
    confirm against the table generator.
    """

    def __init__(self, table_path="./res/table/spec_coef.csv"):
        """Load the coefficient table from ``table_path``.

        File layout, as parsed below:
          * line 1: grid resolution ``res``;
          * next ``res`` lines: one z-axis scale value each (``self.scale``);
          * next ``res**3`` lines: 9 whitespace-separated floats each.
            Together these form a flat array that is reshaped to
            ``(3, res, res, res, 3)``.
        """
        with open(table_path) as f:
            self.res = int(f.readline())
            self.size = self.res**3 * 9
            scale = np.empty(shape=(self.res,), dtype=np.float32)
            table = np.empty(shape=(self.size,), dtype=np.float32)
            for i in range(self.res):
                scale[i] = float(f.readline())
            for i in range(self.size // 9):
                # split() with no separator ignores repeated/trailing
                # whitespace; split(' ', 9) could yield an extra empty field.
                table[i * 9 : i * 9 + 9] = [float(v) for v in f.readline().split()]

        self.table = Vec3.field(shape=(3, self.res, self.res, self.res))
        self.scale = ti.field(dtype=float, shape=self.res)
        self.table.from_numpy(table.reshape(3, self.res, self.res, self.res, 3))
        self.scale.from_numpy(scale)

    @ti.func
    def get_max_component(self, rgb: Vec3):
        # Returns (index, x, y, z):
        #   index = channel holding the largest value (ties favour 0, then 1);
        #   z     = that largest value;
        #   x, y  = the next two channels in cyclic order, scaled by
        #           (res - 1) / z.
        # z is floored at 1e-5 only in the divisor, to avoid dividing by 0.
        index, xyz = 2, rgb
        if rgb.max() == rgb[0]:
            index = 0
            xyz = rgb.yzx
        elif rgb.max() == rgb[1]:
            index = 1
            xyz = rgb.zxy
        scale = (self.res - 1) / max(1e-5, xyz[2])
        xyz[0] *= scale
        xyz[1] *= scale
        return index, xyz[0], xyz[1], xyz[2]

    @ti.func
    def biSearch(self, x: float) -> int:
        # Binary search over self.scale (assumed sorted ascending). Returns
        # the largest i in [0, res - 2] with scale[i] <= x, or 0 if none,
        # so [scale[i], scale[i + 1]] is always a valid interval.
        left = 0
        last_interval = self.res - 2
        size = last_interval

        while size > 0:
            half = size >> 1
            middle = left + half + 1

            if self.scale[middle] <= x:
                left = middle
                size -= half + 1
            else:
                size = half
        return left

    @ti.func
    def fetch(self, rgb: Vec3) -> Vec3:
        # Look up the coefficients for ``rgb`` (clamped to [0, 1]) by
        # trilinear interpolation over the (z, y, x) grid of the slice
        # selected by the largest channel. x and y are on a uniform grid;
        # z uses the non-uniform self.scale axis.
        index, x, y, z = self.get_max_component(ti.math.clamp(rgb, 0, 1))

        xi = min(self.res - 2, int(x))
        yi = min(self.res - 2, int(y))
        zi = min(self.res - 2, self.biSearch(z))

        x0 = x - xi
        y0 = y - yi
        z0 = (z - self.scale[zi]) / (self.scale[zi + 1] - self.scale[zi])

        offset = iVec4(index, zi, yi, xi)
        dx, dy, dz = iVec4(0, 0, 0, 1), iVec4(0, 0, 1, 0), iVec4(0, 1, 0, 0)
        v000, v001 = self.table[offset], self.table[offset + dz]
        v100, v101 = self.table[offset + dx], self.table[offset + dx + dz]
        v010, v011 = self.table[offset + dy], self.table[offset + dy + dz]
        v110, v111 = self.table[offset + dx + dy], self.table[offset + dx + dy + dz]
        # Trilinear interpolation: along x, then y, then z.
        return mix(
            mix(mix(v000, v100, x0), mix(v010, v110, x0), y0),
            mix(mix(v001, v101, x0), mix(v011, v111, x0), y0),
            z0,
        )

    @ti.func
    def eval(self, coef: Vec3, Lambda: float) -> float:
        # Evaluate the sigmoid-of-quadratic spectrum at ``Lambda``
        # (presumably in nm):
        #   x = coef[0]*Lambda^2 + coef[1]*Lambda + coef[2]
        #   s = 0.5 + 0.5 * x / sqrt(1 + x^2)
        # Returns 0 for Lambda <= 300. Above 800, s is attenuated by
        # exp(-((Lambda - 800) / 500)^2).
        result = 0.0
        if Lambda > 300:
            x = (coef[0] * Lambda + coef[1]) * Lambda + coef[2]
            y = 1 / (x * x + 1) ** 0.5
            result = 0.5 * x * y + 0.5
            if Lambda > 800:
                # The fall-off is a function of the wavelength, so the factor
                # is exactly 1 at 800 and the result stays continuous there.
                result *= ti.exp(-((Lambda - 800) ** 2) / (500**2))
        return result


class Spectrum2RGB:
    """Spectrum -> linear RGB conversion with hero-wavelength sampling.

    Uses tabulated XYZ colour-matching functions (the default table name
    suggests CIE 1931) to integrate a spectrum sampled at ``HERO_SAMPLE_N``
    stratified wavelengths. The matrix ``XYZ2lRGB`` then converts the result
    to linear RGB. An :class:`RGB2Spectrum` helper provides the inverse
    direction.
    """

    def __init__(self, table_path="./res/table/ciexyz31_1.csv") -> None:
        """Load the matching-function table from ``table_path``.

        Each line is ``wavelength,x,y,z``. The three curves are divided by the
        largest of their per-curve sums. Wavelengths are assumed to be sorted
        and uniformly spaced (not verified here).

        Raises:
            ValueError: if the table holds fewer than two samples.
        """
        wavelengths = []
        data = []
        with open(table_path) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. at end of file)
                values = line.split(',', 4)
                wavelengths.append(float(values[0]))
                # Only the three matching-function columns feed the Vec3 field.
                data.append([float(v) for v in values[1:4]])

        if len(data) < 2:
            raise ValueError(
                f"CMF table {table_path!r} needs at least 2 samples, "
                f"got {len(data)}"
            )

        self.lambdaMin = min(wavelengths)
        self.lambdaMax = max(wavelengths)
        self.size = len(data)
        sensor = np.array(data, dtype=np.float32)
        sensor /= sensor.sum(axis=0).max()

        # Hero wavelengths are drawn from [sampleMin, sampleMax), which
        # currently coincides with the tabulated range. The range is split
        # into HERO_SAMPLE_N equal strata of width sampleStep.
        self.sampleMin, self.sampleMax = self.lambdaMin, self.lambdaMax
        self.sampleStep = (self.sampleMax - self.sampleMin) / HERO_SAMPLE_N
        # Spacing between consecutive table rows (despite the name).
        self.lambdaRange = (self.lambdaMax - self.lambdaMin) / (self.size - 1)
        self.sensor = Vec3.field(shape=self.size)
        self.sensor.from_numpy(sensor)
        self.rgb2spec = RGB2Spectrum()

    @ti.func
    def rand(self) -> float:
        # Uniform hero wavelength in the first stratum
        # [sampleMin, sampleMin + sampleStep).
        return self.sampleMin + ti.random() * self.sampleStep

    @ti.func
    def sample(self, lamda: float) -> Vec3:
        # Linearly interpolated matching-function values at ``lamda``;
        # zero outside [lambdaMin, lambdaMax].
        ret = Vec3(0)
        if self.lambdaMin <= lamda <= self.lambdaMax:
            offset = (lamda - self.lambdaMin) / self.lambdaRange
            # Clamp so idx + 1 stays in bounds when lamda == lambdaMax. The
            # weight ``offset - idx`` is then 1 and the last row is returned.
            idx = min(int(offset), self.size - 2)
            ret = mix(self.sensor[idx], self.sensor[idx + 1], offset - idx)
        return ret

    @ti.func
    def heroSample(self, spec: Hero, lamda: float) -> Vec3:
        # Monte Carlo estimate of linear RGB for ``spec``, whose lane i is
        # the spectral value at lamda + i * sampleStep. The sum is averaged
        # over the lanes, multiplied by the sampling range width (dividing by
        # the uniform pdf), and converted from XYZ to linear RGB.
        xyz = Vec3(0)
        for i in ti.static(range(HERO_SAMPLE_N)):
            # Hero wavelength sampling
            xyz += self.sample(lamda + i * self.sampleStep) * spec[i]
        return XYZ2lRGB @ (xyz / HERO_SAMPLE_N * (self.sampleMax - self.sampleMin))

    @ti.func
    def heroSampleInv(self, lrgb: Vec3, lamda: float, beta_coef: float) -> Hero:
        # Upsample linear RGB ``lrgb`` to spectral values at the hero
        # wavelengths. Each wavelength is multiplied by ``beta_coef`` before
        # the table-based evaluation.
        coef = self.rgb2spec.fetch(lrgb)
        spec = Hero(0)
        for i in ti.static(range(HERO_SAMPLE_N)):
            # Hero wavelength sampling
            spec[i] = self.rgb2spec.eval(
                coef, (lamda + i * self.sampleStep) * beta_coef
            )
        return spec
